{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Week 9: Normalising flows pt 3 - improved variational posterior with IAF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import tensorflow as tf\n",
    "import tensorflow_probability as tfp\n",
    "from IPython import display\n",
    "\n",
    "tfd = tfp.distributions\n",
    "tfb = tfp.bijectors\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Improved variational posterior"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Seed both TensorFlow (graph-level seed) and NumPy: the notebook samples\n",
    "# with np.random.normal further down, which the TF seed does not control.\n",
    "tf.set_random_seed(100)\n",
    "np.random.seed(100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VAE():\n",
    "\n",
    "    def __init__(self, use_iaf=False):\n",
    "        self.sess = tf.Session()\n",
    "        self.lambda_l2_reg = 0.01\n",
    "        self.learning_rate = 0.001\n",
    "        self.dropout = 1.\n",
    "        self.use_iaf = use_iaf\n",
    "\n",
    "        handles = self._buildGraph()\n",
    "        self.sess.run(tf.global_variables_initializer())\n",
    "\n",
    "        (self.x_in, self.dropout_, self.z_mean, self.z_log_sigma, self.z_sample,\n",
    "         self.x_reconstructed, self.cost, self.global_step, self.train_op,\n",
    "         self.rec_loss, self.kl_loss) = handles\n",
    "\n",
    "    def _buildGraph(self):\n",
    "        x_in = tf.placeholder(tf.float32, shape=[None, 2], name=\"x\")\n",
    "        dropout = tf.placeholder_with_default(1., shape=[], name=\"dropout\")\n",
    "\n",
    "        h = tf.layers.Dense(8, activation=tf.nn.tanh, name=\"encoding/1\")(x_in)\n",
    "        h = tf.layers.Dense(8, activation=tf.nn.tanh, name=\"encoding/2\")(h)\n",
    "        \n",
    "        z_mean = tf.layers.Dense(2, activation=None, name=\"z_mean\")(h)\n",
    "        z_log_sigma = tf.layers.Dense(2, activation=None, name=\"z_log_sigma\")(h)\n",
    "        \n",
    "        z = tfd.MultivariateNormalDiag(loc=z_mean, scale_diag=tf.exp(z_log_sigma))\n",
    "        \n",
    "        if not self.use_iaf:\n",
    "            z_sample = z.sample()\n",
    "        else: \n",
    "            iaf_flow = self.build_iaf_flow(z)\n",
    "            z_sample = iaf_flow.sample()\n",
    "        \n",
    "        h = tf.layers.Dense(8, activation=tf.nn.sigmoid, name=\"decoding/1\")(z_sample)\n",
    "        h = tf.layers.Dense(8, activation=tf.nn.sigmoid ,name=\"decoding/2\")(h)\n",
    "        \n",
    "        x_reconstructed = tf.layers.Dense(2, activation=None, name=\"decoding/out\")(h)\n",
    "        \n",
    "        with tf.name_scope(\"l2_loss\"):\n",
    "            rec_loss = tf.reduce_sum(tf.square(x_reconstructed - x_in), 1)\n",
    "\n",
    "        if not self.use_iaf:\n",
    "            kl_loss = VAE.kullbackLeibler(z_mean, z_log_sigma)\n",
    "        else:\n",
    "            prior = tfd.MultivariateNormalDiag(loc=tf.zeros([2]))\n",
    "            kl_loss = iaf_flow.log_prob(z_sample) - tf.log(prior.prob(z_sample) + 1e-10)\n",
    "\n",
    "        with tf.name_scope(\"l2_regularization\"):\n",
    "            regularizers = [tf.nn.l2_loss(var) for var in self.sess.graph.get_collection(\n",
    "                \"trainable_variables\") if (\"kernel\" in var.name and \"decoding\" not in var.name)]\n",
    "            l2_reg = self.lambda_l2_reg * tf.add_n(regularizers)\n",
    "\n",
    "        with tf.name_scope(\"cost\"):\n",
    "            cost = tf.reduce_mean(rec_loss + kl_loss, name=\"vae_cost\")\n",
    "            cost += l2_reg\n",
    "\n",
    "        global_step = tf.Variable(0, trainable=False)\n",
    "        with tf.name_scope(\"Adam_optimizer\"):\n",
    "            optimizer = tf.train.AdamOptimizer(self.learning_rate)\n",
    "            tvars = tf.trainable_variables()\n",
    "            self.grads_and_vars = optimizer.compute_gradients(cost, tvars)\n",
    "            clipped = [(tf.clip_by_value(grad, -0.1, 0.1), tvar)\n",
    "                    for grad, tvar in self.grads_and_vars]\n",
    "            train_op = optimizer.apply_gradients(clipped, global_step=global_step,\n",
    "                                                 name=\"minimize_cost\")\n",
    "\n",
    "        return (x_in, dropout, z_mean, z_log_sigma, z_sample, x_reconstructed,\n",
    "                cost, global_step, train_op, tf.reduce_mean(rec_loss), tf.reduce_mean(kl_loss))\n",
    "\n",
    "    @staticmethod\n",
    "    def kullbackLeibler(mu, log_sigma):\n",
    "        with tf.name_scope(\"KL_divergence\"):\n",
    "            return -0.5 * tf.reduce_sum(1 + 2 * log_sigma - mu**2 -\n",
    "                                        tf.exp(2 * log_sigma), 1)\n",
    "        \n",
    "    def build_iaf_flow(self, base_dist):\n",
    "        bijectors = [\n",
    "            tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=tfb.masked_autoregressive_default_template(\n",
    "            hidden_layers=[64, 64])),\n",
    "            tfb.Permute(permutation=[1, 0]),\n",
    "            tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=tfb.masked_autoregressive_default_template(\n",
    "            hidden_layers=[64, 64])),\n",
    "            tfb.Permute(permutation=[1, 0]),\n",
    "            tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=tfb.masked_autoregressive_default_template(\n",
    "            hidden_layers=[64, 64])),\n",
    "            tfb.Permute(permutation=[1, 0]),\n",
    "            tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=tfb.masked_autoregressive_default_template(\n",
    "            hidden_layers=[64, 64]))\n",
    "        ]\n",
    "\n",
    "        maf_bijector = tfb.Chain(list(reversed(bijectors)), name='maf_bijector')\n",
    "        return tfd.TransformedDistribution(distribution=base_dist, bijector=tfb.Invert(maf_bijector))\n",
    "\n",
    "    def encode(self, x):\n",
    "        # Encodes data points to factorised Gaussian, before passing through IAF flow (if used)\n",
    "        return self.sess.run([self.z_mean, self.z_log_sigma], feed_dict={self.x_in: x})\n",
    "    \n",
    "    def posterior_sample(self, x):\n",
    "        # Samples from the full posterior (after IAF if used)\n",
    "        return self.sess.run(self.z_sample, feed_dict={self.x_in: x})\n",
    "\n",
    "    def decode(self, zs):\n",
    "        return self.sess.run(self.x_reconstructed, feed_dict={self.z_sample: zs})\n",
    "    \n",
    "    @staticmethod\n",
    "    def plot_posterior_distribution(X):\n",
    "        X1 = X[:64, :]\n",
    "        X2 = X[64:128, :]\n",
    "        X3 = X[128:192, :]\n",
    "        X4 = X[192:, :]\n",
    "        x1_posterior_samples = model.posterior_sample(X1)\n",
    "        x2_posterior_samples = model.posterior_sample(X2)\n",
    "        x3_posterior_samples = model.posterior_sample(X3)\n",
    "        x4_posterior_samples = model.posterior_sample(X4)\n",
    "        plt.close()\n",
    "        plt.figure()\n",
    "        plt.scatter(x1_posterior_samples[:, 0], x1_posterior_samples[:, 1], color='red', s=5)\n",
    "        plt.scatter(x2_posterior_samples[:, 0], x2_posterior_samples[:, 1], color='blue', s=5)\n",
    "        plt.scatter(x3_posterior_samples[:, 0], x3_posterior_samples[:, 1], color='green', s=5)\n",
    "        plt.scatter(x4_posterior_samples[:, 0], x4_posterior_samples[:, 1], color='purple', s=5)\n",
    "        plt.title(\"Posterior distributions\")\n",
    "        display.display(plt.gcf())\n",
    "        display.clear_output(wait=True)\n",
    "\n",
    "    def train(self, x, max_iter=np.inf):\n",
    "        losses = []\n",
    "        iterations = []\n",
    "        while True:  \n",
    "            feed_dict = {self.x_in: x, self.dropout_: self.dropout}\n",
    "            x_reconstructed, cost, rec_loss, kl_loss, _, i = self.sess.run(\n",
    "                [self.x_reconstructed, self.cost, self.rec_loss, \n",
    "                 self.kl_loss, self.train_op, self.global_step], feed_dict\n",
    "            )\n",
    "\n",
    "            if i%500 == 1:\n",
    "                print(\"Iteration {}, cost: \".format(i), cost)\n",
    "                print(\"   rec_loss: {}, kl_loss: {}\".format(rec_loss, kl_loss))\n",
    "                losses.append(cost)\n",
    "                iterations.append(i)\n",
    "                VAE.plot_posterior_distribution(x)\n",
    "\n",
    "            if i >= max_iter:\n",
    "                print(\"Finished training. Final cost at iteration {}: {}\".format(i, cost))\n",
    "                print(\"   rec_loss: {}, kl_loss: {}\".format(rec_loss, kl_loss))\n",
    "                losses.append(cost)\n",
    "                iterations.append(i)\n",
    "                break\n",
    "        return losses, iterations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x1 = np.array([5., 5.])\n",
    "x2 = np.array([-5., 5.])\n",
    "x3 = np.array([-5., -5.])\n",
    "x4 = np.array([5., -5.])\n",
    "\n",
    "X1 = np.vstack((x1,) * 64)\n",
    "X2 = np.vstack((x2,) * 64)\n",
    "X3 = np.vstack((x3,) * 64)\n",
    "X4 = np.vstack((x4,) * 64)\n",
    "\n",
    "X_train = np.vstack((X1, X2, X3, X4))\n",
    "X_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# use_iaf=False: z is sampled from the encoder's diagonal Gaussian and the\n",
    "# closed-form KL term is used; pass True to sample through the IAF chain.\n",
    "model = VAE(use_iaf=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# perf_counter() is a monotonic clock, so the elapsed time cannot be skewed\n",
    "# by system clock adjustments the way a time.time() difference can.\n",
    "start_time = time.perf_counter()\n",
    "losses, iterations = model.train(X_train, max_iter=10000)\n",
    "end_time = time.perf_counter()\n",
    "\n",
    "print(\"Training time: {}\".format(end_time - start_time))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Cost recorded by train() against the global step at which it was logged.\n",
    "plt.plot(iterations, losses)\n",
    "plt.title(\"Training curve\")\n",
    "plt.xlabel(\"Iteration\")\n",
    "plt.ylabel(\"Cost\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the posterior distributions before passing through the IAF flow:\n",
    "# sample directly from the diagonal Gaussian whose parameters encode() returns.\n",
    "\n",
    "num_samples = 256\n",
    "\n",
    "for point, color in zip((x1, x2, x3, x4), ('red', 'blue', 'green', 'purple')):\n",
    "    # A single row goes in, so mean and log_sigma each come back as one row.\n",
    "    mean, log_sigma = model.encode(np.expand_dims(point, 0))\n",
    "    # Broadcasting over `size` replaces stacking num_samples copies of loc/scale.\n",
    "    samples = np.random.normal(loc=mean, scale=np.exp(log_sigma),\n",
    "                               size=(num_samples, mean.shape[-1]))\n",
    "    plt.scatter(samples[:, 0], samples[:, 1], color=color, s=5)\n",
    "plt.title(\"Posterior distributions before IAF flow\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "num_samples = 256\n",
    "\n",
    "# Repeat each data point num_samples times and draw one posterior sample per row.\n",
    "(x1_posterior_samples, x2_posterior_samples,\n",
    " x3_posterior_samples, x4_posterior_samples) = (\n",
    "    model.posterior_sample(np.tile(point, (num_samples, 1)))\n",
    "    for point in (x1, x2, x3, x4)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the posterior distributions after passing through the IAF flow\n",
    "\n",
    "coloured_samples = (\n",
    "    ('red', x1_posterior_samples),\n",
    "    ('blue', x2_posterior_samples),\n",
    "    ('green', x3_posterior_samples),\n",
    "    ('purple', x4_posterior_samples),\n",
    ")\n",
    "for colour, latent in coloured_samples:\n",
    "    plt.scatter(latent[:, 0], latent[:, 1], color=colour, s=5)\n",
    "plt.title(\"Posterior distributions after IAF flow\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Decode the posterior samples of each data point back to data space.\n",
    "x1_decoded, x2_decoded, x3_decoded, x4_decoded = (\n",
    "    model.decode(latent)\n",
    "    for latent in (x1_posterior_samples, x2_posterior_samples,\n",
    "                   x3_posterior_samples, x4_posterior_samples)\n",
    ")\n",
    "\n",
    "coloured_reconstructions = (\n",
    "    ('red', x1_decoded),\n",
    "    ('blue', x2_decoded),\n",
    "    ('green', x3_decoded),\n",
    "    ('purple', x4_decoded),\n",
    ")\n",
    "for colour, decoded in coloured_reconstructions:\n",
    "    plt.scatter(decoded[:, 0], decoded[:, 1], color=colour, s=5)\n",
    "plt.title(\"Reconstructions of data points\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Close the tf.Session the model opened in its constructor.\n",
    "model.sess.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
